Add batched RoPE kernel #3095

Merged
merged 10 commits into vllm-project:main on Mar 13, 2024

Conversation

@tterrysun (Contributor) commented Feb 28, 2024

Problem: Currently we need to call the rotary embedding kernel for each LoRA request, which makes it very inefficient to serve multiple LoRAs with different context lengths.

Solution: Add a batched rotary embedding kernel. Follow-up PRs will pipe it through.

Testing: Batched kernel tests. Follow-up PRs will add end-to-end tests.

@@ -77,6 +77,48 @@ __global__ void rotary_embedding_kernel(
}
}

template<typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
@pcmoritz (Collaborator) commented Feb 29, 2024

This kernel is almost exactly the same as rotary_embedding_kernel, and you can make them the same by adding a const int64_t* __restrict__ cos_sin_cache_offsets argument there (a null pointer if it is not set) and then, further down, doing

int64_t cos_sin_cache_offset = cos_sin_cache_offsets ? cos_sin_cache_offsets[token_idx] : 0;
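As a rough illustration of that suggestion, here is a hypothetical sketch (simplified to a query-only rotation with a reduced argument list; not the code that ultimately landed in this PR) of a single kernel that serves both paths, falling back to offset 0 when the offsets pointer is null:

#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical sketch, not the PR's actual code: one kernel serving both the
// plain and the batched case. When cos_sin_cache_offsets is a null pointer,
// every token falls back to offset 0, i.e. the original non-batched behaviour.
template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
    const int64_t* __restrict__ positions,              // [num_tokens]
    scalar_t* __restrict__ query,                        // [num_tokens, num_heads, rot_dim]
    const scalar_t* __restrict__ cos_sin_cache,          // [rows, rot_dim]
    const int64_t* __restrict__ cos_sin_cache_offsets,   // [num_tokens] or nullptr
    const int rot_dim,
    const int num_heads) {
  const int token_idx = blockIdx.x;

  // The suggested fallback: offset 0 when no per-token offsets are supplied.
  const int64_t cos_sin_cache_offset =
      cos_sin_cache_offsets ? cos_sin_cache_offsets[token_idx] : 0;
  const scalar_t* cache_ptr =
      cos_sin_cache + (cos_sin_cache_offset + positions[token_idx]) * rot_dim;

  const int embed_dim = rot_dim / 2;
  for (int i = threadIdx.x; i < num_heads * embed_dim; i += blockDim.x) {
    const int head = i / embed_dim;
    const int rot = i % embed_dim;
    scalar_t* q = query + ((int64_t)token_idx * num_heads + head) * rot_dim;
    // NeoX-style rotates pairs (i, i + embed_dim); GPT-J-style rotates (2i, 2i + 1).
    const int x_idx = IS_NEOX ? rot : 2 * rot;
    const int y_idx = IS_NEOX ? embed_dim + rot : 2 * rot + 1;
    const scalar_t c = cache_ptr[rot];
    const scalar_t s = cache_ptr[embed_dim + rot];
    const scalar_t x = q[x_idx];
    const scalar_t y = q[y_idx];
    q[x_idx] = x * c - y * s;
    q[y_idx] = y * c + x * s;
  }
}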

@tterrysun (Contributor, Author) commented Mar 1, 2024
cos_sin_cache_offsets is passed as a pointer, so we don't have a good way to determine whether it is set without an auxiliary flag, and we also try to avoid runtime branching in kernel code for performance. Agreed that these two kernels are pretty much the same, so I refactored to avoid too much code duplication.
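For readers following the thread, a hedged sketch of this kind of refactor (assumed names and structure, not the PR's exact code): the rotation body lives in one shared __device__ helper, and the plain and batched __global__ entry points differ only in how they select the cos/sin row, so no runtime branch is added to the inner loop. The rotation is shown query-only and NeoX-style for brevity.

#include <cuda_runtime.h>
#include <cstdint>

// Shared per-token rotation body (query-only, NeoX layout, for brevity).
template <typename scalar_t>
__device__ __forceinline__ void apply_rotary_to_token(
    scalar_t* __restrict__ q,            // [num_heads, rot_dim] for this token
    const scalar_t* __restrict__ cache,  // cos/sin row for this token
    const int rot_dim,
    const int num_heads) {
  const int embed_dim = rot_dim / 2;
  for (int i = threadIdx.x; i < num_heads * embed_dim; i += blockDim.x) {
    const int head = i / embed_dim;
    const int rot = i % embed_dim;
    scalar_t* hq = q + head * rot_dim;
    const scalar_t c = cache[rot];
    const scalar_t s = cache[embed_dim + rot];
    const scalar_t x = hq[rot];
    const scalar_t y = hq[embed_dim + rot];
    hq[rot] = x * c - y * s;
    hq[embed_dim + rot] = y * c + x * s;
  }
}

// Non-batched entry point: the cos/sin row is chosen by position alone.
template <typename scalar_t>
__global__ void rotary_embedding_kernel(
    const int64_t* __restrict__ positions, scalar_t* __restrict__ query,
    const scalar_t* __restrict__ cos_sin_cache,
    const int rot_dim, const int num_heads) {
  const int t = blockIdx.x;
  apply_rotary_to_token(query + (int64_t)t * num_heads * rot_dim,
                        cos_sin_cache + positions[t] * rot_dim,
                        rot_dim, num_heads);
}

// Batched entry point: each token also carries an offset into the
// concatenated cos/sin caches (e.g. one region per LoRA scaling factor).
template <typename scalar_t>
__global__ void batched_rotary_embedding_kernel(
    const int64_t* __restrict__ positions, scalar_t* __restrict__ query,
    const scalar_t* __restrict__ cos_sin_cache,
    const int64_t* __restrict__ cos_sin_cache_offsets,
    const int rot_dim, const int num_heads) {
  const int t = blockIdx.x;
  const int64_t row = cos_sin_cache_offsets[t] + positions[t];
  apply_rotary_to_token(query + (int64_t)t * num_heads * rot_dim,
                        cos_sin_cache + row * rot_dim,
                        rot_dim, num_heads);
}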

@pcmoritz (Collaborator)

Do you have a (micro-)benchmark that shows the difference between batched and non-batched to justify the change?

@pcmoritz (Collaborator) commented Mar 4, 2024

@tterrysun Can you add the micro benchmarks so we can measure the performance here? You can put it into benchmarks/kernels :)

@tterrysun (Contributor, Author) commented Mar 6, 2024

Benchmarking command:
nsys profile -t nvtx,osrt --force-overwrite=true --stats=true --output=./rope_bm python benchmarks/kernels/benchmark_rope.py
Results:
Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)  Style    Range
    89.3           374257          1  374257.0  374257.0    374257    374257          0.0  PushPop  non-batched
    10.7            44781          1   44781.0   44781.0     44781     44781          0.0  PushPop  batched

Note that this simulates serving 4 LoRAs; the more LoRAs served, the bigger the difference between the single batched kernel launch and multiple non-batched kernel launches. The majority of the difference should come from the Python side. When serving a single LoRA, the two should be equivalent.
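For context on where the PushPop rows above come from: they are NVTX ranges that nsys aggregates in its NVTX range summary. The actual benchmark is the Python script named in the command above; the standalone CUDA sketch below (a hypothetical dummy kernel, illustrative only) shows the same annotation pattern, one range around a per-LoRA launch loop and one around a single batched launch, and can be profiled the same way with nsys profile -t nvtx --stats=true.

// Illustrative only, not the PR's benchmark: a standalone CUDA program showing
// how NVTX push/pop ranges produce the PushPop rows reported by nsys. The dummy
// kernel stands in for the real rotary embedding kernels.
#include <cuda_runtime.h>
#include <nvtx3/nvToolsExt.h>

__global__ void dummy_rope_kernel(float* q, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) q[i] *= 0.5f;  // stand-in for the real rotation math
}

int main() {
  const int n = 1 << 20;
  const int num_loras = 4;  // mirrors the 4-LoRA setup measured above
  float* q = nullptr;
  cudaMalloc(&q, n * sizeof(float));

  // Non-batched path: one launch per LoRA, as in the old per-request code.
  nvtxRangePushA("non-batched");
  for (int l = 0; l < num_loras; ++l) {
    dummy_rope_kernel<<<(n + 255) / 256, 256>>>(q, n);
  }
  cudaDeviceSynchronize();  // keep the GPU work inside the NVTX range
  nvtxRangePop();

  // Batched path: a single launch covering all tokens of all LoRAs.
  nvtxRangePushA("batched");
  dummy_rope_kernel<<<(n + 255) / 256, 256>>>(q, n);
  cudaDeviceSynchronize();
  nvtxRangePop();

  cudaFree(q);
  return 0;
}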

@tterrysun marked this pull request as ready for review March 6, 2024 21:06
@tterrysun requested a review from pcmoritz March 6, 2024 21:06
type=int,
choices=[64, 80, 96, 112, 128, 256],
default=128)
parser.add_argument("--rottery-dim",
Review comment (Collaborator):

Suggested change
parser.add_argument("--rottery-dim",
parser.add_argument("--rotary-dim",

seq_len=args.seq_len,
num_heads=args.num_heads,
head_size=args.head_size,
rotary_dim=args.rottery_dim,
Review comment (Collaborator):

Suggested change
rotary_dim=args.rottery_dim,
rotary_dim=args.rotary_dim,

@@ -158,27 +169,30 @@ def __init__(
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
scaling_factors: List[float],
Review comment (Collaborator):

Can we also take in a float by itself (coerce it into a list inside the init)?

@tterrysun requested a review from Yard1 March 8, 2024 23:36
@Yard1 merged commit 7e9bd08 into vllm-project:main Mar 13, 2024
23 checks passed
starmpcc pushed a commit to starmpcc/vllm that referenced this pull request Mar 14, 2024
@@ -107,7 +108,9 @@ def _forward(
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]

cos_sin = self.cos_sin_cache[positions]
self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
Review comment (Contributor):

should use positions.device rather than positions.get_device().
https://pytorch.org/docs/stable/generated/torch.Tensor.get_device.html

Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
@ByronHsu commented
Hi, I am wondering: is this kernel currently used? I don't see changes in the model code, and I'm not sure where the follow-up PRs are.
